{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark LLM Inference: Qwen-2.5-14b Data Structuring\n",
    "\n",
    "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct), using open weights hosted on Hugging Face.\n",
    "\n",
    "Qwen2.5-14B-Instruct is an instruction-tuned version of the Qwen2.5-14B base model. We'll show how to use the model to convert unstructured text into a structured schema for downstream tasks.\n",
    "\n",
    "**Note:** This example demonstrates **tensor parallelism**, which requires multiple GPUs per node. For standalone users, make sure to use a Spark worker with 2 GPUs. If you follow the Databricks or Dataproc instructions, make sure to include the `tp` argument to the cluster startup scripts."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For cloud environments, store the model on the distributed file system.\n",
    "if on_databricks:\n",
    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    model_path = f\"{models_dir}/qwen2.5-14b\"\n",
    "elif on_dataproc:\n",
    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
    "    os.makedirs(models_dir, exist_ok=True)\n",
    "    model_path = f\"{models_dir}/qwen2.5-14b\"\n",
    "else:\n",
    "    model_path = os.path.abspath(\"qwen2.5-14b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the model from the Hugging Face Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f75ef1f2071f413da5ae502589293c62",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "model_path = snapshot_download(\n",
    "    repo_id=\"Qwen/Qwen2.5-14B-Instruct\",\n",
    "    local_dir=model_path\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pyspark.sql.types import *\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\n",
    "from pyspark.ml.functions import predict_batch_udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import datasets\n",
    "from datasets import load_dataset\n",
    "datasets.disable_progress_bars()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/03/20 00:55:18 INFO SparkEnv: Registering MapOutputTracker\n",
      "25/03/20 00:55:18 INFO SparkEnv: Registering BlockManagerMaster\n",
      "25/03/20 00:55:18 INFO SparkEnv: Registering BlockManagerMasterHeartbeat\n",
      "25/03/20 00:55:18 INFO SparkEnv: Registering OutputCommitCoordinator\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "        \n",
    "    # For standalone users: adjust executor.cores and task.resource.gpu.amount based on available cores\n",
    "    conf.set(\"spark.executor.cores\", \"24\")  \n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.083333\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"2\")  # 2 GPUs per executor\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
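  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fractional `spark.task.resource.gpu.amount` above follows from the other two settings: with 2 GPUs and 24 cores per executor, each task claims 2/24 of a GPU. A minimal sketch of that arithmetic, assuming the goal is one task slot per core; the helper name `task_gpu_amount` is hypothetical and not part of Spark:\n",
    "\n",
    "```python\n",
    "def task_gpu_amount(gpus_per_executor: int, executor_cores: int) -> str:\n",
    "    # Split the executor's GPUs evenly across one task slot per core,\n",
    "    # truncating (not rounding up) to six decimal places.\n",
    "    millionths = gpus_per_executor * 1_000_000 // executor_cores\n",
    "    return f'{millionths / 1_000_000:.6f}'\n",
    "\n",
    "amount = task_gpu_amount(gpus_per_executor=2, executor_cores=24)\n",
    "print(amount)\n",
    "```\n",
    "\n",
    "For the values used in this notebook the helper returns `0.083333`, matching the setting above. Standalone users with a different core count can substitute their own numbers."
   ]
  },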
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load and Preprocess DataFrame\n",
    "\n",
    "Load the first 500 samples of the [Amazon Video Game Product Reviews dataset](https://huggingface.co/datasets/logankells/amazon_product_reviews_video_games) from Hugging Face and store them in a Spark DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "adead9f390a042a286de64a03a88fca8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/6.00 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    }
   ],
   "source": [
    "product_reviews_ds = load_dataset(\"LoganKells/amazon_product_reviews_video_games\", split=\"train\", streaming=True)\n",
    "product_reviews_pds = pd.Series([sample[\"reviewText\"] for sample in product_reviews_ds.take(500)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = spark.createDataFrame(product_reviews_pds, schema=StringType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------------------------------------------------------------------+\n",
      "|                                                                                               value|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "|Installing the game was a struggle (because of games for windows live bugs).Some championship rac...|\n",
      "|If you like rally cars get this game you will have fun.It is more oriented to &#34;European marke...|\n",
      "|1st shipment received a book instead of the game.2nd shipment got a FAKE one. Game arrived with a...|\n",
      "|I had Dirt 2 on Xbox 360 and it was an okay game. I started playing games on my laptop and bought...|\n",
      "|Overall this is a well done racing game, with very good graphics for its time period. My family h...|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(5, truncate=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format each sample into the Qwen chat template, including a system prompt to guide generation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "system_prompt = \"\"\"You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\n",
    "IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\n",
    "For each review, analyze and output EXACTLY this JSON structure:\n",
    "{\n",
    "  \"primary_sentiment\": [EXACTLY ONE OF: \"positive\", \"negative\", \"neutral\", \"mixed\"],\n",
    "  \"sentiment_score\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\n",
    "  \"purchase_intention\": [EXACTLY ONE OF: \"will repurchase\", \"might repurchase\", \"will not repurchase\", \"recommends alternatives\", \"uncertain\"]\n",
    "}\n",
    "\n",
    "Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\n",
    "\"\"\"\n",
    "\n",
    "df = df.select(\n",
    "    concat(\n",
    "        lit(\"<|im_start|>system\\n\"),\n",
    "        lit(system_prompt),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n",
    "        lit(\"Analyze this review: \"),\n",
    "        col(\"value\"),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n",
    "    ).alias(\"prompt\")\n",
    ")"
   ]
  },
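  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `concat` expression above lays out a system turn, a user turn, and an open assistant turn using the `<|im_start|>` / `<|im_end|>` markers. To check the layout on a single string outside Spark, the same template can be sketched as a plain Python function. The helper name `build_prompt` and the sample inputs are hypothetical and not part of the notebook's pipeline:\n",
    "\n",
    "```python\n",
    "def build_prompt(system_prompt: str, review: str) -> str:\n",
    "    # Mirrors the Spark concat: system turn, user turn, then an open assistant turn.\n",
    "    return (\n",
    "        '<|im_start|>system\\n'\n",
    "        + system_prompt\n",
    "        + '<|im_end|>\\n<|im_start|>user\\n'\n",
    "        + 'Analyze this review: '\n",
    "        + review\n",
    "        + '<|im_end|>\\n<|im_start|>assistant\\n'\n",
    "    )\n",
    "\n",
    "prompt = build_prompt('Reply with JSON only.\\n', 'Great game, would buy again.')\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "Ending the prompt with an open assistant turn is what cues the model to generate the reply rather than continue the user's text."
   ]
  },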
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\n",
      "IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\n",
      "For each review, analyze and output EXACTLY this JSON structure:\n",
      "{\n",
      "  \"primary_sentiment\": [EXACTLY ONE OF: \"positive\", \"negative\", \"neutral\", \"mixed\"],\n",
      "  \"sentiment_score\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\n",
      "  \"purchase_intention\": [EXACTLY ONE OF: \"will repurchase\", \"might repurchase\", \"will not repurchase\", \"recommends alternatives\", \"uncertain\"]\n",
      "}\n",
      "\n",
      "Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\n",
      "<|im_end|>\n",
      "<|im_start|>user\n",
      "Analyze this review: Installing the game was a struggle (because of games for windows live bugs).Some championship races and cars can only be \"unlocked\" by buying them as an addon to the game. I paid nearly 30 dollars when the game was new. I don't like the idea that I have to keep paying to keep playing.I noticed no improvement in the physics or graphics compared to Dirt 2.I tossed it in the garbage and vowed never to buy another codemasters game. I'm really tired of arcade style rally/racing games anyway.I'll continue to get my fix from Richard Burns Rally, and you should to. :)http://www.amazon.com/Richard-Burns-Rally-PC/dp/B000C97156/ref=sr_1_1?ie=UTF8&qid;=1341886844&sr;=8-1&keywords;=richard+burns+rallyThank you for reading my review! If you enjoyed it, be sure to rate it as helpful.<|im_end|>\n",
      "<|im_start|>assistant\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(df.take(1)[0].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "data_path = \"spark-dl-datasets/amazon_video_game_reviews\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path = \"dbfs:/FileStore/\" + data_path\n",
    "\n",
    "df.write.mode(\"overwrite\").parquet(data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using vLLM Server\n",
    "In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\n",
    "- Define a vLLM inference function, which sends inference requests to the local server on a given node.\n",
    "- Wrap the vLLM inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server-mg.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the helper class from `server_utils.py`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import VLLMServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are currently some hanging issues with vLLM's `torch.compile` on Databricks, which we are working to resolve. For now we will enforce eager mode on Databricks, which disables compilation at some performance cost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enforce_eager = bool(on_databricks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start vLLM servers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `VLLMServerManager` handles the lifecycle of vLLM server instances across the Spark cluster:\n",
    "- Find available ports for HTTP\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shut down servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = \"qwen-2.5-14b\"\n",
    "server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#vllm-serve) as keyword arguments when starting the servers. Note that startup can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs.\n",
    "\n",
    "Here, we set `tensor_parallel_size` to the number of GPUs per node:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-20 01:04:42,978 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\n",
      "2025-03-20 01:04:42,979 - INFO - Starting 2 VLLM servers.\n",
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'spark-dl-inference-vllm-tp-w-0': (35438, [7000]),\n",
       " 'spark-dl-inference-vllm-tp-w-1': (35288, [7000])}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor_parallel_size = int(spark.conf.get(\"spark.executor.resource.gpu.amount\"))\n",
    "server_manager.start_servers(tensor_parallel_size=tensor_parallel_size,\n",
    "                             gpu_memory_utilization=0.95,\n",
    "                             max_model_len=6600,\n",
    "                             task=\"generate\",\n",
    "                             enforce_eager=enforce_eager,\n",
    "                             wait_retries=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the hostname-to-URL mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "host_to_http_url = server_manager.host_to_http_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def vllm_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    import json\n",
    "    import requests\n",
    "\n",
    "    # Base URL of the vLLM server on this node (assumed to have the form http://<host>:<port>).\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    \n",
    "    def predict(inputs):\n",
    "        response = requests.post(\n",
    "            f\"{url}/v1/completions\",\n",
    "            json={\n",
    "                \"model\": model_name,\n",
    "                \"prompt\": inputs.tolist(),\n",
    "                \"max_tokens\": 50,\n",
    "                \"temperature\": 0.7,\n",
    "                \"top_p\": 0.8,\n",
    "                \"repetition_penalty\": 1.05,\n",
    "            }\n",
    "        )\n",
    "        # Surface HTTP errors before trying to parse the body.\n",
    "        response.raise_for_status()\n",
    "        # Each completion is expected to be a JSON object matching the UDF's return schema.\n",
    "        result_dicts = [json.loads(o[\"text\"]) for o in response.json()[\"choices\"]]\n",
    "        return result_dicts\n",
    "    \n",
    "    return predict"
   ]
  },
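  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `json.loads` call in the client function raises an exception if a completion is not valid JSON, which fails the task and, with `spark.task.maxFailures` set to 1, can fail the job. One way to tolerate an occasional malformed generation is to fall back to empty fields. A minimal sketch, where `parse_or_null` and `FIELDS` are hypothetical names and not part of the notebook's code:\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "FIELDS = ('primary_sentiment', 'sentiment_score', 'purchase_intention')\n",
    "\n",
    "def parse_or_null(text):\n",
    "    # Return a dict with the three expected keys; values are None when parsing fails.\n",
    "    try:\n",
    "        parsed = json.loads(text)\n",
    "    except (json.JSONDecodeError, TypeError):\n",
    "        parsed = {}\n",
    "    if not isinstance(parsed, dict):\n",
    "        parsed = {}\n",
    "    return {field: parsed.get(field) for field in FIELDS}\n",
    "\n",
    "print(parse_or_null('not json'))\n",
    "```\n",
    "\n",
    "Whether `None` values convert cleanly for the `IntegerType` field depends on the pandas and Arrow conversion in your environment, so treat this as a starting point rather than a drop-in replacement."
   ]
  },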
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\n",
    "                             return_type=StructType([\n",
    "                                 StructField(\"primary_sentiment\", StringType()),\n",
    "                                 StructField(\"sentiment_score\", IntegerType()),\n",
    "                                 StructField(\"purchase_intention\", StringType())\n",
    "                             ]),\n",
    "                             batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "df = spark.read.parquet(data_path).repartition(16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 31:=================================================>      (14 + 2) / 16]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 29.6 ms, sys: 6.89 ms, total: 36.5 ms\n",
      "Wall time: 33 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\"))).select(\"prompt\", \"outputs.*\")\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 34:=================================================>      (14 + 2) / 16]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 25.6 ms, sys: 6.73 ms, total: 32.3 ms\n",
      "Wall time: 32 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\"))).select(\"prompt\", \"outputs.*\")\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 37:>                                                         (0 + 1) / 1]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------------------------------------+-----------------+---------------+-------------------+\n",
      "|                                            prompt|primary_sentiment|sentiment_score| purchase_intention|\n",
      "+--------------------------------------------------+-----------------+---------------+-------------------+\n",
      "|<|im_start|>system\\nYou are a specialized revie...|         positive|              9|    will repurchase|\n",
      "|<|im_start|>system\\nYou are a specialized revie...|         positive|              9|    will repurchase|\n",
      "|<|im_start|>system\\nYou are a specialized revie...|         positive|              8|    will repurchase|\n",
      "|<|im_start|>system\\nYou are a specialized revie...|         negative|              4|will not repurchase|\n",
      "|<|im_start|>system\\nYou are a specialized revie...|            mixed|              6|   might repurchase|\n",
      "+--------------------------------------------------+-----------------+---------------+-------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "preds.show(5, truncate=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Review: <|im_start|>system\n",
      "You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\n",
      "IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\n",
      "For each review, analyze and output EXACTLY this JSON structure:\n",
      "{\n",
      "  \"primary_sentiment\": [EXACTLY ONE OF: \"positive\", \"negative\", \"neutral\", \"mixed\"],\n",
      "  \"sentiment_score\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\n",
      "  \"purchase_intention\": [EXACTLY ONE OF: \"will repurchase\", \"might repurchase\", \"will not repurchase\", \"recommends alternatives\", \"uncertain\"]\n",
      "}\n",
      "\n",
      "Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\n",
      "<|im_end|>\n",
      "<|im_start|>user\n",
      "Analyze this review: I have never played anything like this since. Everything from Sly  Racoon, to Ratchet and Clank, owe it to this.Wicked witch Gruntilda takes Banjo's sister to hey layer, miles away in a realistic 3D cartoon world.Banjo is a bear with Kazooie a bird in his backpack that can help him jump and fly and basically you learn to do lots of things with it. You solve puzzles via action and collect tolkens across lovely maps. Mumbo Jumbo transforms Banjo into some other creatures along the way. You can fly. It was amazing. A full adventure all the way to end. We played it for months and I have NEVER played anything like it again. The makers of Donkey Kong released it at the best time. It is now up to the future generations to make adventure concepts better than this one. This is one of the best N64 games ever.<|im_end|>\n",
      "<|im_start|>assistant\n",
      "\n",
      "Sentiment: positive, Score: 9, Status: will repurchase\n"
     ]
    }
   ],
   "source": [
    "sample = results[0]\n",
    "print(\"Review:\", sample[\"prompt\"])\n",
    "print(f\"Sentiment: {sample['primary_sentiment']}, Score: {sample['sentiment_score']}, Status: {sample['purchase_intention']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Shut down server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-20 01:19:32,218 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\n",
      "2025-03-20 01:19:33,872 - INFO - Successfully stopped 2 VLLM servers.           \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True, True]"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n",
    "    spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-vllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
